{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "6eddab55",
   "metadata": {},
   "source": [
    "# Lesson 6 - Multi-LoRA\n",
    "\n",
    "\n",
    "In this lesson, we'll see how to efficiently serve dozens of fine-tuned models together in a single deployment without sacrificing latency."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1501fea0",
   "metadata": {},
   "source": [
    "### Import required packages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c7964cc5-0d8d-4ff1-9cdc-89032e45fa6e",
   "metadata": {
    "height": 148
   },
   "outputs": [],
   "source": [
    "import copy\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import random\n",
    "import time\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2b62bb64",
   "metadata": {},
   "source": [
    "### Let's create a new model\n",
    "\n",
    "We will start with creating an extension to the model from lesson 5. It has a custom helper function for computing the LoRA layer step with multiple LoRAs per batch."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ecf12393-deae-4daa-a93a-0b7faa3ad3ff",
   "metadata": {
    "height": 522
   },
   "outputs": [],
   "source": [
    "class AbstractMultiLoraModel(torch.nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        \n",
    "        # hidden_size = 10\n",
    "        # set this so low to ensure we are not \n",
    "        # compute-bound by the linear layer\n",
    "        # this is only an issue when running on CPU, \n",
    "        # for GPUs we can set this much\n",
    "        # higher and still avoid being compute bound\n",
    "        self.embedding = torch.nn.Embedding(10, 10)\n",
    "        self.linear = torch.nn.Linear(10, 10)\n",
    "        self.lm_head = torch.nn.Linear(10, 10)\n",
    "\n",
    "    def linear_lora(\n",
    "        self,\n",
    "        x: torch.Tensor,                 # (batch_size, seq_len, in_features)\n",
    "        loras_a: torch.Tensor,           # (num_loras, in_features, rank)\n",
    "        loras_b: torch.Tensor,           # (num_loras, rank, out_features)\n",
    "        lora_indices: torch.LongTensor,  # (batch_size,)\n",
    "    ) -> torch.Tensor:\n",
    "        # y[i] = x[i] @ loras_a[lora_idx] @ loras_b[lora_idx]\n",
    "        raise NotImplementedError()\n",
    "\n",
    "    def forward(self, input_ids, loras_a, loras_b, lora_indices):\n",
    "        x = self.embedding(input_ids)\n",
    "        x = self.linear_lora(x, loras_a, loras_b, lora_indices)\n",
    "        x = self.lm_head(x)\n",
    "        return x"
   ]
  },
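  {
   "cell_type": "markdown",
   "id": "a7c31f02",
   "metadata": {},
   "source": [
    "In other words, `linear_lora` applies the shared base linear layer plus a per-row low-rank update. For each batch row $i$,\n",
    "\n",
    "$$y_i = x_i W^\\top + b + x_i A_{s_i} B_{s_i},$$\n",
    "\n",
    "where $W$ and $b$ are the weights of the shared `self.linear` layer, $s_i$ = `lora_indices[i]` selects one of the `num_loras` adapters, $A_{s_i}$ has shape `(in_features, rank)`, and $B_{s_i}$ has shape `(rank, out_features)`. The implementations below differ only in how they evaluate the $x_i A_{s_i} B_{s_i}$ term across the batch."
   ]
  },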
  {
   "cell_type": "markdown",
   "id": "2ae0ebbf",
   "metadata": {},
   "source": [
    "### Using a loop\n",
    "\n",
    "Our first attempt to infer across multiple LoRAs will be straightforward: just loop over every row in the batch and apply the correct LoRA using an index mapping: `batch_index --> lora_index`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "94861153-8bfe-4ae5-9b31-fa7f78e23793",
   "metadata": {
    "height": 267
   },
   "outputs": [],
   "source": [
    "class LoopMultiLoraModel(AbstractMultiLoraModel):\n",
    "    def linear_lora(\n",
    "        self,\n",
    "        x: torch.Tensor,                 # (batch_size, seq_len, in_features)\n",
    "        loras_a: torch.Tensor,           # (num_loras, in_features, lora_rank)\n",
    "        loras_b: torch.Tensor,           # (num_loras, lora_rank, out_features)\n",
    "        lora_indices: torch.LongTensor,  # (batch_size,)\n",
    "    ) -> torch.Tensor:\n",
    "        y = self.linear(x)\n",
    "        for batch_idx, lora_idx in enumerate(lora_indices.numpy()):\n",
    "            lora_a = loras_a[lora_idx]\n",
    "            lora_b = loras_b[lora_idx]\n",
    "            y[batch_idx] += x[batch_idx] @ lora_a @ lora_b\n",
    "        return y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a75c87fe-41e3-45ba-aeec-6e00b48e9d53",
   "metadata": {
    "height": 250
   },
   "outputs": [],
   "source": [
    "# toy example of a detokenizer. The vocabular only consists of 10 words (different colors)\n",
    "detokenizer = [\n",
    "    \"red\",\n",
    "    \"orange\",\n",
    "    \"yellow\",\n",
    "    \"green\",\n",
    "    \"blue\",\n",
    "    \"indigo\",\n",
    "    \"violet\",\n",
    "    \"magenta\",\n",
    "    \"marigold\",\n",
    "    \"chartreuse\",\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8e45b3e3-df57-4da1-af74-6658d677b9a1",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "# dummy inputs\n",
    "input_ids = torch.LongTensor([[0, 1, 2, 3, 4, 5, 6, 7]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b1f7a87-f359-40d9-b5f1-e516c855f6bb",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "torch.manual_seed(42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ecd7fc8d-be71-416b-888d-391714dd8eeb",
   "metadata": {
    "height": 131
   },
   "outputs": [],
   "source": [
    "def generate_token(model, **kwargs):\n",
    "    with torch.no_grad():\n",
    "        logits = model(**kwargs)\n",
    "    last_logits = logits[:, -1, :]\n",
    "    next_token_ids = last_logits.argmax(dim=1)\n",
    "\n",
    "    return [detokenizer[token_id] for token_id in next_token_ids]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a30f863f-36b9-4f2c-9a06-33a3c5ef565c",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "model = LoopMultiLoraModel()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ecbbe14c",
   "metadata": {},
   "source": [
    "### Let's try it!\n",
    "\n",
    "We will try this over a few random LoRAs using a fixed tensor of input_ids. If our multi-LoRA generation process is working as designed, we should see a variety of different outputs generated as we randomly iterate over the LoRAs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "818f27ef-9e6e-4aa5-bebd-06c2ffa7e711",
   "metadata": {
    "height": 386
   },
   "outputs": [],
   "source": [
    "# constants\n",
    "bs = 1\n",
    "num_loras = 64\n",
    "h = 10\n",
    "r = 2\n",
    "\n",
    "# create contiguous blocks for 64 random LoRA weights\n",
    "loras_a = torch.randn(num_loras, h, r)\n",
    "loras_b = torch.randn(num_loras, r, h)\n",
    "\n",
    "for i in range(10):\n",
    "    # randomize the LoRAs each iteration\n",
    "    lora_indices = torch.randint(num_loras, (bs,), dtype=torch.long)\n",
    "    next_token = generate_token(\n",
    "        model,\n",
    "        input_ids=input_ids,\n",
    "        loras_a=loras_a,\n",
    "        loras_b=loras_b,\n",
    "        lora_indices=lora_indices,\n",
    "    )\n",
    "    print(next_token)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1108cbf5",
   "metadata": {},
   "source": [
    "### Let's benchmark our multi-LoRA system!\n",
    "\n",
    "We will measure the average latency to generate a single token as the batch size increases and each element within the batch can have a different LoRA adapter (chosen randomly)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "798ce936-fbad-4a49-8c86-2c2b4ec3747b",
   "metadata": {
    "height": 607
   },
   "outputs": [],
   "source": [
    "# constants\n",
    "seq_len = 8\n",
    "vocab_size = 10\n",
    "nsamples = 500\n",
    "max_batch_size = 64\n",
    "\n",
    "\n",
    "def benchmark(model):\n",
    "    avg_latencies = []\n",
    "    for bs in range(1, max_batch_size + 1):\n",
    "        latencies = []\n",
    "        for _ in range(nsamples):\n",
    "            # randomize the inputs and LoRA indices\n",
    "            input_ids = torch.randint(\n",
    "                vocab_size, (bs, seq_len), dtype=torch.long)\n",
    "            lora_indices = torch.randint(\n",
    "                num_loras, (bs,), dtype=torch.long)\n",
    "\n",
    "            # measure the end-to-end latency for \n",
    "            # generating a single token\n",
    "            t0 = time.time()\n",
    "            next_token = generate_token(\n",
    "                model,\n",
    "                input_ids=input_ids,\n",
    "                loras_a=loras_a,\n",
    "                loras_b=loras_b,\n",
    "                lora_indices=lora_indices,\n",
    "            )\n",
    "            latencies.append(time.time() - t0)\n",
    "\n",
    "        # average the latency across all the samples\n",
    "        latency_s = sum(latencies) / len(latencies)\n",
    "        avg_latencies.append(latency_s)\n",
    "        print(bs, latency_s)\n",
    "    return avg_latencies"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b1c960d4",
   "metadata": {},
   "source": [
    "**Note:** Your results might differ from those shown in the video, but they will still follow the same pattern."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "33d1f038-a3a9-4e4f-be18-36af0a31572a",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "avg_latencies_loop = benchmark(model)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3e88a95a",
   "metadata": {},
   "source": [
    "### Let's visualize it!\n",
    "\n",
    "**Note**: Your plot may vary slightly from the one shown in the video, yet it will exhibit a similar pattern."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cf2c5122-37da-4a58-8e1a-ca211dc2551f",
   "metadata": {
    "height": 165
   },
   "outputs": [],
   "source": [
    "x = list(range(1, max_batch_size + 1))\n",
    "plt.plot(x, avg_latencies_loop, label=\"loop\")\n",
    "\n",
    "plt.xlabel('Batch Size')\n",
    "plt.ylabel('Avg Latency (s)')\n",
    "plt.title('Multi-LoRA latency w.r.t. batch size')\n",
    "plt.legend()\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4adf9077",
   "metadata": {},
   "source": [
    "### Let's vectorize the LoRA computation\n",
    "\n",
    "We will vectorize the LoRA computation by:\n",
    "\n",
    "1. Gather the LoRA weight for each batch into a single tensor using `torch.index_select`.\n",
    "2. Apply LoRA computation once for the entire input tensor."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5801f273-741e-4466-bb63-b8913562723b",
   "metadata": {
    "height": 284
   },
   "outputs": [],
   "source": [
    "class GatheredMultiLoraModel(AbstractMultiLoraModel):\n",
    "    def linear_lora(\n",
    "        self,\n",
    "        x: torch.Tensor,                 # (batch_size, seq_len, in_features)\n",
    "        loras_a: torch.Tensor,           # (num_loras, in_features, lora_rank)\n",
    "        loras_b: torch.Tensor,           # (num_loras, lora_rank, out_features)\n",
    "        lora_indices: torch.LongTensor,  # (batch_size,)\n",
    "    ) -> torch.Tensor:\n",
    "        y = self.linear(x)\n",
    "        \n",
    "        # gather the LoRA weights into a new tensor and apply\n",
    "        lora_a = torch.index_select(loras_a, 0, lora_indices) # (batch_size, in_features, lora_rank)\n",
    "        lora_b = torch.index_select(loras_b, 0, lora_indices) # (batch_size, lora_rank, out_features)\n",
    "        y += x @ lora_a @ lora_b\n",
    "        return y"
   ]
  },
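  {
   "cell_type": "markdown",
   "id": "b4d92e17",
   "metadata": {},
   "source": [
    "As an optional sanity check (a minimal sketch, not part of the original lesson code), we can compare the gathered implementation against the loop implementation on identical base weights and a random batch; the two should produce the same logits up to floating-point tolerance. The helper names here (`loop_model`, `check_input_ids`, etc.) are just for this check."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c8e51a90",
   "metadata": {},
   "outputs": [],
   "source": [
    "# optional sanity check: the loop and gathered implementations\n",
    "# should agree when given identical base weights and inputs\n",
    "loop_model = LoopMultiLoraModel()\n",
    "gathered_model = GatheredMultiLoraModel()\n",
    "# copy the shared base weights so both models are identical\n",
    "gathered_model.load_state_dict(loop_model.state_dict())\n",
    "\n",
    "check_bs = 4\n",
    "check_input_ids = torch.randint(vocab_size, (check_bs, seq_len), dtype=torch.long)\n",
    "check_lora_indices = torch.randint(num_loras, (check_bs,), dtype=torch.long)\n",
    "\n",
    "with torch.no_grad():\n",
    "    y_loop = loop_model(check_input_ids, loras_a, loras_b, check_lora_indices)\n",
    "    y_gathered = gathered_model(check_input_ids, loras_a, loras_b, check_lora_indices)\n",
    "\n",
    "print(torch.allclose(y_loop, y_gathered, atol=1e-5))"
   ]
  },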
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f8aafb6e-6981-4c83-9f9d-60d2e8958746",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "model = GatheredMultiLoraModel()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2f6ae3db",
   "metadata": {},
   "source": [
    "**Note:** Your results might differ from those shown in the video, but they will still follow the same pattern."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c3843d58-4fd5-4141-a73a-69d1d01db89d",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "avg_latencies_gathered = benchmark(model)"
   ]
  },
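  {
   "cell_type": "markdown",
   "id": "d2f74a6b",
   "metadata": {},
   "source": [
    "As a rough summary (a small sketch, not part of the original lesson code), we can also compute how much faster the gathered implementation is than the loop at the smallest and largest batch sizes; `speedups` is just an illustrative name."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e9a03c55",
   "metadata": {},
   "outputs": [],
   "source": [
    "# per-batch-size speedup of the gathered implementation over the loop\n",
    "speedups = [l / g for l, g in zip(avg_latencies_loop, avg_latencies_gathered)]\n",
    "print(f\"speedup at batch size 1: {speedups[0]:.2f}x\")\n",
    "print(f\"speedup at batch size {max_batch_size}: {speedups[-1]:.2f}x\")"
   ]
  },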
  {
   "cell_type": "markdown",
   "id": "f7603255",
   "metadata": {},
   "source": [
    "### Let's visualize it!\n",
    "\n",
    "**Note**: Your plot may vary slightly from the one shown in the video, yet it will exhibit a similar pattern."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fee3360b-a610-4c48-91d1-09d33647f1aa",
   "metadata": {
    "height": 182
   },
   "outputs": [],
   "source": [
    "x = list(range(1, max_batch_size + 1))\n",
    "plt.plot(x, avg_latencies_loop, label=\"loop\")\n",
    "plt.plot(x, avg_latencies_gathered, label=\"gathered\")\n",
    "\n",
    "plt.xlabel('Batch Size')\n",
    "plt.ylabel('Avg Latency (s)')\n",
    "plt.title('Multi-LoRA latency w.r.t. batch size')\n",
    "plt.legend()\n",
    "\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
